{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73448c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3983584",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch_geometric.datasets as datasets\n",
    "import torch_geometric.data as data\n",
    "import torch_geometric.transforms as transforms\n",
    "import networkx as nx\n",
    "from torch_geometric.utils.convert import to_networkx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7991e79d",
   "metadata": {},
   "source": [
    "## Data Handling in PyG"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6f91aef",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cc3dc75",
   "metadata": {},
   "source": [
    "Let's create a dummy graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "681cd953",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = torch.rand((100, 16), dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc81154e",
   "metadata": {},
   "outputs": [],
   "source": [
    "rows = np.random.choice(100, 500)\n",
    "cols = np.random.choice(100, 500)\n",
    "edges = torch.tensor([rows, cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd18e171",
   "metadata": {},
   "outputs": [],
   "source": [
    "edges_attr = np.random.choice(3,500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ec8b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ys = torch.rand((100)).round().long()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8e7de3a",
   "metadata": {},
   "source": [
    "Convert the graph information into a PyG Data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b01db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = data.Data(x=embeddings, edge_index=edges, edge_attr=edges_attr, y=ys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c09c7db",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75ec4ee0",
   "metadata": {},
   "source": [
    "Let's visualize the information contained in the data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85efb8c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "for prop in graph:\n",
    "    print(prop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21ed3be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis = to_networkx(graph)\n",
    "\n",
    "node_labels = graph.y.numpy()\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.figure(1,figsize=(15,13)) \n",
    "nx.draw(vis, cmap=plt.get_cmap('Set3'),node_color = node_labels,node_size=70,linewidths=6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49035bc9",
   "metadata": {},
   "source": [
    "### Batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35c96b81",
   "metadata": {},
   "source": [
    "With the Batch object we can represent multiple graphs as a single disconnected graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5ec5bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph2 = graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f22592",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = data.Batch().from_data_list([graph, graph2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd8c5069",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Number of graphs:\",batch.num_graphs)\n",
    "print(\"Graph at index 1:\",batch[1])\n",
    "print(\"Retrieve the list of graphs:\\n\",len(batch.to_data_list()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "061a15d0",
   "metadata": {},
   "source": [
    "### Cluster"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4854f8a",
   "metadata": {},
   "source": [
    "ClusterData groups the nodes of a graph into a specific number of cluster for faster computation in large graphs, then use ClusterLoader to load batches of clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57f64410",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#cluster = data.ClusterData(graph, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a346ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#clusterloader = data.ClusterLoader(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8cfee13",
   "metadata": {},
   "source": [
    "### Sampler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53ce00ce",
   "metadata": {},
   "source": [
    "For each convolutional layer, sample a maximum of nodes from each neighborhood (as in GraphSAGE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c0ae419",
   "metadata": {},
   "outputs": [],
   "source": [
    "sampler = data.NeighborSampler(graph.edge_index, sizes=[3,10], batch_size=4,\n",
    "                                  shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4679ddc",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for s in sampler:\n",
    "    print(s)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c9cdf37",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Batch size:\", s[0])\n",
    "print(\"Number of unique nodes involved in the sampling:\",len(s[1]))\n",
    "print(\"Number of neighbors sampled:\", len(s[2][0].edge_index[0]), len(s[2][1].edge_index[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeea37b6",
   "metadata": {},
   "source": [
    "### Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "322ee147",
   "metadata": {},
   "source": [
    "List all the available datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3b49caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets.__all__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea39f999",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = 'Cora'\n",
    "transform = transforms.Compose([\n",
    "    transforms.AddTrainValTestMask('train_rest', num_val=500, num_test=500),\n",
    "    transforms.TargetIndegree(),\n",
    "])\n",
    "cora = datasets.Planetoid('./data', name, pre_transform=transforms.NormalizeFeatures(), transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fca4fd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "aids = datasets.TUDataset(root=\"./data\", name=\"AIDS\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "235ae6d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"AIDS info:\")\n",
    "print('# of graphs:', len(aids))\n",
    "print('# Classes (graphs)', aids.num_classes)\n",
    "print('# Edge features', aids.num_edge_features)\n",
    "print('# Edge labels', aids.num_edge_labels)\n",
    "print('# Node features', aids.num_node_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1afc18",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Cora info:\")\n",
    "print('# of graphs:', len(cora))\n",
    "print('# Classes (nodes)', cora.num_classes)\n",
    "print('# Edge features', cora.num_edge_features)\n",
    "print('# Node features', cora.num_node_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a1894c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "aids.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd72167",
   "metadata": {},
   "outputs": [],
   "source": [
    "aids[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "012ff95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "cora.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8582706d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cora[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a0d0f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "cora_loader = data.DataLoader(cora)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "806906c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "for l in cora_loader:\n",
    "    print(l)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "309dd730",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
